import torch
import torch.nn as nn
import copy
import dgl
import numpy as np
import json
from typing import List, Union
import torch
import torch.nn as nn
import scipy.sparse as sp
from scipy.sparse import coo_matrix
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import Dataset, DataLoader
import math
import networkx as nx
from networkx.algorithms.community import greedy_modularity_communities
import community as community_louvain
from generative_model.generative_model import CDRVAE

def filter_edges(
        user_ids: torch.Tensor,
        item_ids: torch.Tensor,
        source_item_ids: torch.Tensor,
        community_detection
):
    """
    Drop cross-community edges of a user-item bipartite graph.

    A bipartite graph is built from the ``(user_ids[k], item_ids[k])`` pairs,
    communities are detected on it, and every edge whose two endpoints fall in
    different communities is considered "cross-community". Such an edge is
    removed only when its item id is contained in ``source_item_ids``.

    Parameters
    ----------
    user_ids : 1-D tensor of length N with the user id of every interaction.
    item_ids : 1-D tensor of length N with the item id of every interaction
        (aligned with ``user_ids``).
    source_item_ids : 1-D tensor of item ids that are eligible for removal.
    community_detection : ``'louvain'`` or
        ``'greedy_modularity_communities'``.

    Returns
    -------
    final_user_ids : 1-D tensor (dtype of ``user_ids``) with the kept users.
    final_item_ids : 1-D tensor (dtype of ``item_ids``) with the kept items.

    Raises
    ------
    ValueError
        If ``community_detection`` is not one of the supported method names.
    """
    supported_methods = ('louvain', 'greedy_modularity_communities')
    if community_detection not in supported_methods:
        # Previously an unknown name left ``partition`` unbound and surfaced
        # as a confusing NameError further down.
        raise ValueError(
            f"Unknown community_detection method {community_detection!r}; "
            f"expected one of {supported_methods}"
        )

    # ===== 1. Plain Python containers are easier to work with below =====
    user_list = user_ids.tolist()
    item_list = item_ids.tolist()
    print(len(user_list))
    print(len(item_list))
    source_items_set = set(source_item_ids.tolist())  # O(1) membership tests

    # ===== 2. Build the bipartite graph and detect communities =====
    partition = {}
    if user_list:
        G = nx.Graph()
        # Prefix node names so a user id and an item id can never collide.
        G.add_nodes_from((f"u_{u}", {"bipartite": 0}) for u in set(user_list))
        G.add_nodes_from((f"i_{i}", {"bipartite": 1}) for i in set(item_list))
        G.add_edges_from(
            (f"u_{u}", f"i_{i}") for u, i in zip(user_list, item_list)
        )

        if community_detection == 'louvain':
            # best_partition already returns {node_name: community_id}.
            partition = community_louvain.best_partition(G)
        else:
            # greedy_modularity_communities returns a list of node sets, so it
            # has to be flattened into the same {node_name: community_id} map.
            communities = greedy_modularity_communities(G)
            partition = {
                node: community_id
                for community_id, members in enumerate(communities)
                for node in members
            }

    # ===== 3. Cross-community edges: endpoints in different communities =====
    # Work on de-duplicated pairs, which correspond 1:1 to the graph's edges.
    cross_pairs = {
        (u, i)
        for (u, i) in set(zip(user_list, item_list))
        if partition[f"u_{u}"] != partition[f"i_{i}"]
    }
    print('跨社群的边', len(cross_pairs))

    # ===== 4. Only cross edges whose item is in source_item_ids are removed ==
    edges_to_remove_set = {
        (u, i) for (u, i) in cross_pairs if i in source_items_set
    }

    # ===== 5. Filter the original interaction list (duplicates included) =====
    final_pairs = [
        pair for pair in zip(user_list, item_list)
        if pair not in edges_to_remove_set
    ]

    final_user_ids = torch.tensor([p[0] for p in final_pairs], dtype=user_ids.dtype)
    final_item_ids = torch.tensor([p[1] for p in final_pairs], dtype=item_ids.dtype)
    print(final_user_ids.shape)
    return final_user_ids, final_item_ids



class OverlapUserDataset(Dataset):
    """Dataset that serves the rows of a tensor one at a time.

    The original author's note describes ``data`` as an ``(N, M)`` tensor;
    each sample is ``data[idx]`` and the dataset length is ``data.size(0)``.
    """

    def __init__(self, data):
        # Kept by reference; no copy is made.
        self.data = data

    def __len__(self):
        # Number of rows (first dimension) of the wrapped tensor.
        num_rows = self.data.size(0)
        return num_rows

    def __getitem__(self, idx):
        # Plain indexing along the first dimension.
        sample = self.data[idx]
        return sample


class CDR_Reproduce(nn.Module):
    """Re-generates (augments / prunes) the source-domain interaction graph.

    The module owns user and item embedding tables, a ``CDRVAE`` generator and
    a merged, symmetrically normalised adjacency matrix built from the source
    and target interactions. It can

    * produce a training loss for the generator on interactions of masked
      users (``calculate_loss``),
    * append generated user-item edges (``generate_edges``), and
    * optionally drop cross-community edges (``data_reproduce``).
    """

    def __init__(self,
                 config,
                 source_u,
                 source_i,
                 target_u,
                 target_i,
                 total_num_users,
                 total_num_items,
                 overlapped_num_users,
                 source_num_users,
                 source_num_items,
                 target_num_users,
                 target_num_items
                 ):
        """Store the interaction lists / sizes and build all sub-components.

        ``config`` is read for the keys ``device``, ``embedding_size``,
        ``generate_layers``, ``mask_num_overlap_users`` and
        ``train_generate_model_batch_size`` here; other methods read
        ``random_mask``, ``generate_edges``, ``add_edges``, ``filter_edges``
        and ``community_detection_method``.
        """
        super().__init__()
        self.config = config
        self.device = config['device']
        self.source_u = source_u
        self.source_i = source_i
        self.target_u = target_u
        self.target_i = target_i
        self.total_num_users = total_num_users
        self.total_num_items = total_num_items
        self.overlapped_num_users = overlapped_num_users
        self.source_num_users = source_num_users
        self.source_num_items = source_num_items
        self.target_num_users = target_num_users
        self.target_num_items = target_num_items
        self.user_embedding = torch.nn.Embedding(num_embeddings=self.total_num_users,
                                                 embedding_dim=config['embedding_size'])
        self.item_embedding = torch.nn.Embedding(num_embeddings=self.total_num_items,
                                                 embedding_dim=config['embedding_size'])
        self.n_layers = config['generate_layers']
        # Normalised adjacency of the un-masked source graph merged with the
        # target graph.
        self.merge_norm_adj_matrix = self.get_norm_adj_mat_a(
            self.get_sparse_matrix(self.total_num_users,
                                   self.total_num_items,
                                   self.source_u,
                                   self.source_i).astype(np.float32),
            self.get_sparse_matrix(self.total_num_users,
                                   self.total_num_items,
                                   self.target_u,
                                   self.target_i).astype(np.float32),
        )
        # Masking is done repeatedly: each training step masks a sub-graph.
        self.generater = CDRVAE(self.config, self.total_num_users, self.total_num_items)
        self.Dataset = OverlapUserDataset(
            self.split_overlap_users_n_groups(config['mask_num_overlap_users'])
        )
        # Every dataloader batch holds rows of user ids that are to be masked.
        self.OverlapUserDataloader = DataLoader(self.Dataset,
                                                batch_size=config['train_generate_model_batch_size'],
                                                shuffle=True,
                                                num_workers=0)

    def split_overlap_users_n_groups(self, mask_num_overlap_users):
        """Shuffle the ids ``0..overlapped_num_users`` and group them in rows.

        Returns a ``(rows, mask_num_overlap_users)`` long tensor. When the
        number of ids is not a multiple of ``mask_num_overlap_users`` the tail
        is padded with ids drawn uniformly at random from the same range.
        """
        total_elements = self.overlapped_num_users + 1
        # A random permutation of 0..x is exactly "arange shuffled".
        arr = torch.randperm(total_elements)

        # Round the element count up to a multiple of the group size.
        rows = math.ceil(total_elements / mask_num_overlap_users)
        needed = rows * mask_num_overlap_users

        if needed > total_elements:
            pads = torch.randint(low=0, high=total_elements,
                                 size=(needed - total_elements,))
            arr = torch.cat([arr, pads])
        return arr.reshape(-1, mask_num_overlap_users)

    def _random_mask_edges(self, tensor: torch.Tensor, n: int) -> torch.Tensor:
        """Set up to ``n`` randomly chosen ``False`` entries to ``True``.

        Entries that are already ``True`` are left untouched. ``tensor`` is
        modified in place and also returned. If fewer than ``n`` entries are
        ``False`` all of them are set.

        Parameters
        ----------
        tensor : boolean tensor (at least 1-D).
        n : number of ``False`` positions to flip.
        """
        # as_tuple=True keeps one index vector per dimension; unlike
        # ``nonzero(...).squeeze()`` this does not collapse to a 0-d tensor
        # when exactly one position is False.
        false_positions = torch.nonzero(~tensor, as_tuple=True)
        num_false = false_positions[0].numel()
        chosen = torch.randperm(num_false)[:n]
        tensor[tuple(index[chosen] for index in false_positions)] = True
        return tensor

    def mask_interactions(self, masked_over_u_idx):
        """Build a masked graph and the labels (masked edges) to predict.

        All source interactions whose user is in ``masked_over_u_idx`` are
        masked; if ``config['random_mask']`` is non-zero that many additional
        interactions are masked at random.

        Returns
        -------
        (label_u, label_i, masked_norm_adj_matrix), all moved to
        ``self.device``: the users / items of the masked interactions and the
        normalised adjacency of the remaining source graph merged with the
        target graph.
        """
        # Positions of the masked users inside the source interaction list.
        mask = torch.isin(self.source_u, masked_over_u_idx)
        if self.config['random_mask'] != 0:
            mask = self._random_mask_edges(mask, self.config['random_mask'])
        label_u_interactions = self.source_u[mask]
        label_i_interactions = self.source_i[mask]
        un_masked_source_u_interactions = self.source_u[~mask]
        un_masked_source_i_interactions = self.source_i[~mask]
        # Merge the masked source-domain graph with the target-domain graph.
        masked_norm_adj_matrix = self.get_norm_adj_mat_a(
            self.get_sparse_matrix(self.total_num_users, self.total_num_items,
                                   un_masked_source_u_interactions,
                                   un_masked_source_i_interactions).astype(np.float32),
            self.get_sparse_matrix(self.total_num_users, self.total_num_items,
                                   self.target_u, self.target_i).astype(np.float32))
        return (label_u_interactions.to(self.device),
                label_i_interactions.to(self.device),
                masked_norm_adj_matrix.to(self.device))

    def get_norm_adj_mat_a(self, interaction_matrix_s, interaction_matrix_t):
        """Return the normalised adjacency ``D^-1/2 A D^-1/2`` as a sparse tensor.

        ``A`` is the symmetric, binary ``(users + items) x (users + items)``
        adjacency containing every stored entry of both interaction matrices
        (item columns are shifted by ``total_num_users``). The result is a
        float32 sparse COO tensor placed on ``self.device``.
        """
        n_nodes = self.total_num_users + self.total_num_items
        inter_S = interaction_matrix_s.tocoo()
        inter_T = interaction_matrix_t.tocoo()

        # Node indices of both graphs; items live after the users.
        user_idx = np.concatenate([inter_S.row, inter_T.row]).astype(np.int64)
        item_idx = (np.concatenate([inter_S.col, inter_T.col]).astype(np.int64)
                    + self.total_num_users)
        rows = np.concatenate([user_idx, item_idx])
        cols = np.concatenate([item_idx, user_idx])

        # Built through the public COO constructor instead of the private
        # ``dok_matrix._update``. Converting to CSR sums duplicated edges, so
        # the values are reset to 1 to keep the adjacency binary.
        A = sp.coo_matrix(
            (np.ones(rows.shape[0], dtype=np.float32), (rows, cols)),
            shape=(n_nodes, n_nodes),
        ).tocsr()
        A.data[:] = 1.0

        # Degree = number of neighbours; epsilon avoids a divide-by-zero
        # warning for isolated nodes.
        degree = A.getnnz(axis=1).astype(np.float64) + 1e-7
        D = sp.diags(np.power(degree, -0.5))
        L = (D @ A @ D).tocoo()

        # Convert the normalised adjacency to a torch sparse tensor.
        indices = torch.from_numpy(np.vstack([L.row, L.col]).astype(np.int64))
        values = torch.from_numpy(L.data.astype(np.float32))
        SparseL = torch.sparse_coo_tensor(indices, values, torch.Size(L.shape))
        # Follow the configured device rather than a hard-coded 'cuda'.
        return SparseL.to(self.device)

    def get_sparse_matrix(self, user_num, item_num, src, tgt):
        """Return a ``(user_num, item_num)`` COO matrix with a 1 at every ``(src[k], tgt[k])``."""
        data = np.ones(len(src))
        mat = coo_matrix((data, (src, tgt)), shape=(user_num, item_num))
        return mat

    def get_ego_embeddings(self):
        """Return the user embedding table stacked on top of the item table."""
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings

    def calculate_loss(self, mask_u_idx):
        """Return the generator's training loss for one batch of masked users.

        ``mask_u_idx`` holds the ids of the overlapped users to mask; the
        labels passed to the generator are the interactions removed by the
        mask.
        """
        label_u_interactions, label_i_interactions, masked_norm_adj_matrix = \
            self.mask_interactions(mask_u_idx)
        all_embeddings = self.get_ego_embeddings()
        loss = self.generater.forward(all_embeddings,
                                      label_u_interactions,
                                      label_i_interactions,
                                      masked_norm_adj_matrix,
                                      self.merge_norm_adj_matrix,
                                      'train')
        return loss

    def _get_topk_items(self,
                        A: torch.Tensor,
                        B: torch.Tensor,
                        k: int,
                        chunk_size: int = 1024):
        """Return, per row of ``A``, the indices of the top-``k`` rows of ``B``.

        Scores are the entries of ``A @ B.T``. ``B`` is processed in chunks of
        ``chunk_size`` rows so the full ``(a, b)`` score matrix is never
        materialised.

        Parameters
        ----------
        A : ``(a, dim)`` tensor.
        B : ``(b, dim)`` tensor.
        k : number of best-scoring rows of ``B`` to return per row of ``A``.
        chunk_size : rows of ``B`` scored per step; tune to available memory.

        Returns
        -------
        ``(a, k)`` long tensor of row indices into ``B``, best score first.

        Raises
        ------
        ValueError
            If ``k`` exceeds the number of rows of ``B``.
        """
        device = A.device
        a = A.shape[0]
        b = B.shape[0]
        if k > b:
            raise ValueError(f"k={k} exceeds the number of candidate rows ({b})")

        # Running global top-k: scores start at -inf, indices at -1.
        global_top_vals = torch.full((a, k), float('-inf'), device=device)
        global_top_inds = torch.full((a, k), -1, dtype=torch.long, device=device)
        row_idx = torch.arange(a, device=device).unsqueeze(1)

        for start in range(0, b, chunk_size):
            end = min(start + chunk_size, b)
            # Scores of A against this slice of B => (a, chunk_len)
            partial_result = torch.matmul(A, B[start:end, :].transpose(0, 1))

            # A chunk may hold fewer than k rows (last chunk, or
            # chunk_size < k), so never ask topk for more than it has.
            local_k = min(k, end - start)
            partial_top_vals, partial_top_inds = partial_result.topk(local_k, dim=1)
            # Chunk-local indices -> indices into the full B.
            partial_top_inds = partial_top_inds + start

            # Merge with the running candidates and keep the best k.
            merged_vals = torch.cat([global_top_vals, partial_top_vals], dim=1)
            merged_inds = torch.cat([global_top_inds, partial_top_inds], dim=1)
            global_top_vals, merged_top_pos = merged_vals.topk(k, dim=1)
            # merged_top_pos indexes the merged columns; map back to B rows.
            global_top_inds = merged_inds[row_idx, merged_top_pos]
        return global_top_inds

    def generate_edges(self):
        """Append generated edges to the source interactions.

        For every user id in ``[target_num_users + 1, total_num_users)`` the
        ``config['generate_edges']`` best-scoring items among the item rows
        ``[target_num_items + 1, total_num_items)`` are selected.

        Returns
        -------
        (re_source_u, re_source_i): the original source interactions followed
        by the generated ones.
        """
        all_embeddings = self.get_ego_embeddings()
        generating_user_id = torch.arange(self.target_num_users + 1, self.total_num_users)
        u_embeddings, i_embeddings = self.generater.forward(all_embeddings,
                                                            generating_user_id,
                                                            1,
                                                            self.merge_norm_adj_matrix,
                                                            self.merge_norm_adj_matrix,
                                                            'generate')
        item_offset = self.target_num_items + 1
        i_embeddings = i_embeddings[item_offset:self.total_num_items]
        edges_per_user = self.config['generate_edges']
        local_item_idx = self._get_topk_items(u_embeddings, i_embeddings,
                                              edges_per_user).flatten()
        # The top-k indices are relative to the slice taken above, so the
        # slice offset is added back to obtain ids in the same index space as
        # the slice bounds.
        # NOTE(review): assumes self.source_i holds such global item ids (they
        # are used as columns of a total_num_items-wide matrix) - confirm.
        generated_item_id = (local_item_idx + item_offset).to(self.source_i.device)
        generated_user_id = torch.repeat_interleave(
            generating_user_id, repeats=edges_per_user
        ).to(self.source_u.device)
        re_source_u = torch.cat((self.source_u, generated_user_id), dim=0)
        re_source_i = torch.cat((self.source_i, generated_item_id), dim=0)
        return re_source_u, re_source_i

    def data_reproduce(self):
        """Apply the configured augmentation steps and return the source edges.

        ``config['add_edges']`` appends generated edges first;
        ``config['filter_edges']`` then removes cross-community edges. Both
        steps overwrite ``self.source_u`` / ``self.source_i``.
        """
        # ``== True`` is kept on purpose: truthy non-boolean config values
        # (e.g. the string 'False') must not enable a step.
        if self.config['add_edges'] == True:
            self.source_u, self.source_i = self.generate_edges()
        if self.config['filter_edges'] == True:
            self.source_u, self.source_i = filter_edges(self.source_u,
                                                        self.source_i,
                                                        torch.unique(self.source_i),
                                                        self.config['community_detection_method'])
        return self.source_u, self.source_i




def train(config, model):
    """Train ``model`` with Adam on the batches of its overlap-user dataloader.

    Parameters
    ----------
    config : mapping providing ``generate_learning_rate`` (Adam learning rate)
        and ``generate_model_epochs`` (number of epochs).
    model : module exposing ``parameters()``, ``train()``,
        ``OverlapUserDataloader`` (an iterable of batches) and
        ``calculate_loss(batch)`` returning a scalar tensor.

    The epoch index and the loss of the most recent batch are printed once
    per epoch. Nothing is returned.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=config['generate_learning_rate'])
    # Pre-bind so the per-epoch print cannot fail when the dataloader yields
    # no batches at all.
    loss = None
    for epoch in range(config['generate_model_epochs']):
        print(epoch)
        for mask_u_idx in model.OverlapUserDataloader:
            model.train()
            optimizer.zero_grad()
            loss = model.calculate_loss(mask_u_idx)
            loss.backward()
            optimizer.step()
        print('loss:', loss)
